{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to censor texts in a dialogue system\n",
    "\n",
    "We detect the following topics (for the Russian language):\n",
    "* Suicide\n",
    "* Violence\n",
    "* Sex\n",
    "* LGBT\n",
    "* Drugs\n",
    "* Terrorism\n",
    "* Politics\n",
    "\n",
    "The train and dev data contain our own hand-labeled dialogues (data 1), plus OpenSubTitles and LentaRU texts labeled by matching against a corpus of obscene words. The test data contains other utterances that we also labeled by hand (data 2). Note that data 1 and data 2 differ considerably in style and distribution.\n",
    "\n",
    "P.S. For more details about the data, email me directly."
   ]
  },
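  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of labeling text by matching it against a word list, as mentioned above for OpenSubTitles and LentaRU, the next cell gives a minimal sketch. It is not the procedure that was actually used to build the data. The lexicon `LEXICON`, the helper `label_by_lexicon`, the simple word tokenization, and the assumption that label `'1'` marks flagged text are all hypothetical and chosen only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Hypothetical stand-in lexicon: the real obscene-word corpus is not part of this notebook.\n",
    "LEXICON = {\"badword\", \"worseword\"}\n",
    "\n",
    "# For str patterns in Python 3, \\w also matches Cyrillic letters.\n",
    "TOKEN_RE = re.compile(r\"\\w+\")\n",
    "\n",
    "\n",
    "def label_by_lexicon(text, lexicon=LEXICON):\n",
    "    \"\"\"Return '1' if any lowercased token is in the lexicon, else '0' (assumed label scheme).\"\"\"\n",
    "    tokens = TOKEN_RE.findall(text.lower())\n",
    "    return \"1\" if any(tok in lexicon for tok in tokens) else \"0\"\n",
    "\n",
    "\n",
    "samples = [\"This is fine\", \"What a BADWORD day\"]\n",
    "print([label_by_lexicon(s) for s in samples])"
   ]
  },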
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "\n",
    "import sys\n",
    "import warnings\n",
    "\n",
    "import pandas as pd  # needed by the ad-hoc learner.predict(df=pd.DataFrame(...)) cells\n",
    "\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "sys.path.append(\"../\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test MultiBERT\n",
    "### PoolingLinearClassifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.data import bert_data_clf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df_path = \"/home/ubuntu/censor/train.csv\"\n",
    "valid_df_path = \"/home/ubuntu/censor/dev.csv\"\n",
    "test_df_path = \"/home/ubuntu/censor/test.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=21380, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "data = bert_data_clf.LearnDataClass.create(\n",
    "    train_df_path=train_df_path,\n",
    "    valid_df_path=valid_df_path,\n",
    "    idx2cls_path=\"/home/ubuntu/censor/idx2cls.txt\",\n",
    "    clear_cache=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.models.classifiers import BERTBaseClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BERTBaseClassifier.create(len(data.train_ds.cls2idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.train.train_clf import NerLearner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'1': 0, '0': 1}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_ds.cls2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner = NerLearner(\n",
    "    model, data, \"/home/ubuntu/censor/cls.cpt\", t_total=num_epochs * len(data.train_dl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "885903"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.get_n_trainable_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "learner.fit(epochs=num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.load_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.data.bert_data_clf import get_data_loader_for_predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "dl, ds = get_data_loader_for_predict(data, df_path=test_df_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Predicting', max=168, style=ProgressStyle(description_width='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "preds = learner.predict(dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn_crfsuite.metrics import flat_classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "cls_report = flat_classification_report([f.cls for f in ds], preds, digits=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9804    0.9903    0.9853      2471\n",
      "           1     0.8644    0.7574    0.8074       202\n",
      "\n",
      "   micro avg     0.9727    0.9727    0.9727      2673\n",
      "   macro avg     0.9224    0.8739    0.8963      2673\n",
      "weighted avg     0.9716    0.9727    0.9719      2673\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(cls_report)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### BERTBiLSTMAttnClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.data import bert_data_clf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df_path = \"/home/ubuntu/censor/train.csv\"\n",
    "valid_df_path = \"/home/ubuntu/censor/dev.csv\"\n",
    "test_df_path = \"/home/ubuntu/censor/test.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n"
     ]
    }
   ],
   "source": [
    "data = bert_data_clf.LearnDataClass.create(\n",
    "    train_df_path=train_df_path,\n",
    "    valid_df_path=valid_df_path,\n",
    "    idx2cls_path=\"/home/ubuntu/censor/idx2cls.txt\",\n",
    "    clear_cache=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.models.classifiers import BERTBiLSTMAttnClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BERTBiLSTMAttnClassifier.create(len(data.train_ds.cls2idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.train.train_clf import NerLearner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'1': 0, '0': 1}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_ds.cls2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner = NerLearner(\n",
    "    model, data, \"/home/ubuntu/censor/cls.cpt2\", t_total=num_epochs * len(data.train_dl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2889999"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.get_n_trainable_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "learner.fit(epochs=num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.load_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.data.bert_data_clf import get_data_loader_for_predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "dl, ds = get_data_loader_for_predict(data, df_path=test_df_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Predicting', max=168, style=ProgressStyle(description_width='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "preds = learner.predict(dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn_crfsuite.metrics import flat_classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "cls_report = flat_classification_report([f.cls for f in ds], preds, digits=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9899    0.9923    0.9911      2471\n",
      "           1     0.9031    0.8762    0.8894       202\n",
      "\n",
      "   micro avg     0.9835    0.9835    0.9835      2673\n",
      "   macro avg     0.9465    0.9343    0.9403      2673\n",
      "weighted avg     0.9833    0.9835    0.9834      2673\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(cls_report)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### BERTBiLSTMAttnClassifier (with OpenSubTitles added)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.data import bert_data_clf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df_path = \"/home/ubuntu/censor/train.csv\"\n",
    "valid_df_path = \"/home/ubuntu/censor/dev.csv\"\n",
    "test_df_path = \"/home/ubuntu/censor/test2.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:pytorch_pretrained_bert.tokenization:The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n",
      "INFO:pytorch_pretrained_bert.tokenization:loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt from cache at /home/ubuntu/.pytorch_pretrained_bert/96435fa287fbf7e469185f1062386e05a075cadbf6838b74da22bf64b080bc32.99bcd55fc66f4f3360bc49ba472b940b8dcf223ea6a345deb969d607ca900729\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=7159232, style=ProgressStyle(des…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "data = bert_data_clf.LearnDataClass.create(\n",
    "    train_df_path=train_df_path,\n",
    "    valid_df_path=None,\n",
    "    idx2cls_path=\"/home/ubuntu/censor/idx2cls.txt\",\n",
    "    clear_cache=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.models.classifiers import BERTBiLSTMAttnClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pytorch_pretrained_bert.modeling:loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz from cache at /home/ubuntu/.pytorch_pretrained_bert/731c19ddf94e294e00ec1ba9a930c69cc2a0fd489b25d3d691373fae4c0986bd.4e367b0d0155d801930846bb6ed98f8a7c23e0ded37888b29caa37009a40c7b9\n",
      "INFO:pytorch_pretrained_bert.modeling:extracting archive file /home/ubuntu/.pytorch_pretrained_bert/731c19ddf94e294e00ec1ba9a930c69cc2a0fd489b25d3d691373fae4c0986bd.4e367b0d0155d801930846bb6ed98f8a7c23e0ded37888b29caa37009a40c7b9 to temp dir /tmp/tmpa2vfzwcs\n",
      "INFO:pytorch_pretrained_bert.modeling:Model config {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"directionality\": \"bidi\",\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pooler_fc_size\": 768,\n",
      "  \"pooler_num_attention_heads\": 12,\n",
      "  \"pooler_num_fc_layers\": 3,\n",
      "  \"pooler_size_per_head\": 128,\n",
      "  \"pooler_type\": \"first_token_transform\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"vocab_size\": 119547\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model = BERTBiLSTMAttnClassifier.create(len(data.train_ds.cls2idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.train.train_clf import NerLearner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'0': 0, '1': 1}"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_ds.cls2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner = NerLearner(\n",
    "    model, data, \"/home/ubuntu/censor/cls.cpt3\", t_total=num_epochs * 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2889999"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.get_n_trainable_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "learner.fit(epochs=num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.load_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.data.bert_data_clf import get_data_loader_for_predict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### On the validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "dl, ds = get_data_loader_for_predict(data, df_path=valid_df_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = learner.predict(dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn_crfsuite.metrics import flat_classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "cls_report = flat_classification_report([f.cls for f in ds], preds, digits=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9960    0.9903    0.9931      1243\n",
      "           1     0.9941    0.9975    0.9958      2031\n",
      "\n",
      "   micro avg     0.9948    0.9948    0.9948      3274\n",
      "   macro avg     0.9950    0.9939    0.9945      3274\n",
      "weighted avg     0.9948    0.9948    0.9948      3274\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(cls_report)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### On the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "dl, ds = get_data_loader_for_predict(data, df_path=test_df_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Predicting', max=66, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "preds = learner.predict(dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn_crfsuite.metrics import flat_classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "cls_report = flat_classification_report([f.cls for f in ds], preds, digits=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.5550    0.4436    0.4931       523\n",
      "           1     0.5366    0.6444    0.5856       523\n",
      "\n",
      "   micro avg     0.5440    0.5440    0.5440      1046\n",
      "   macro avg     0.5458    0.5440    0.5393      1046\n",
      "weighted avg     0.5458    0.5440    0.5393      1046\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(cls_report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "errors = []\n",
    "for pred, feature in zip(preds, ds):\n",
    "    if pred != feature.cls:\n",
    "        errors.append({\n",
    "            \"text\": \" \".join(feature.tokens[1:-1]),\n",
    "            \"pred\": pred,\n",
    "            \"true\": feature.cls\n",
    "        })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Predicting', max=1, style=ProgressStyle(description_width='in…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['0', '0', '0']"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learner.predict(df=pd.DataFrame(\n",
    "    {\n",
    "        \"text\": [\"Привет, пидор\", \"Путин гад\", \"что правда, гандон?\"], \"cls\": [\"1\", \"0\", \"0\"]\n",
    "}))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### BERTBiLSTMAttnClassifier (with OpenSubTitles and LentaRU added)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.data import bert_data_clf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df_path = \"/home/ubuntu/censor/train_f.csv\"\n",
    "valid_df_path = \"/home/ubuntu/censor/dev_f.csv\"\n",
    "test_df_path = \"/home/ubuntu/censor/test2.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:pytorch_pretrained_bert.tokenization:The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n",
      "INFO:pytorch_pretrained_bert.tokenization:loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt from cache at /home/ubuntu/.pytorch_pretrained_bert/96435fa287fbf7e469185f1062386e05a075cadbf6838b74da22bf64b080bc32.99bcd55fc66f4f3360bc49ba472b940b8dcf223ea6a345deb969d607ca900729\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=120917, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "data = bert_data_clf.LearnDataClass.create(\n",
    "    train_df_path=train_df_path,\n",
    "    valid_df_path=valid_df_path,\n",
    "    idx2cls_path=\"/home/ubuntu/censor/idx2cls.txt\",\n",
    "    clear_cache=True,\n",
    "    batch_size=64\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.models.classifiers import BERTBiLSTMAttnClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pytorch_pretrained_bert.modeling:loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz from cache at /home/ubuntu/.pytorch_pretrained_bert/731c19ddf94e294e00ec1ba9a930c69cc2a0fd489b25d3d691373fae4c0986bd.4e367b0d0155d801930846bb6ed98f8a7c23e0ded37888b29caa37009a40c7b9\n",
      "INFO:pytorch_pretrained_bert.modeling:extracting archive file /home/ubuntu/.pytorch_pretrained_bert/731c19ddf94e294e00ec1ba9a930c69cc2a0fd489b25d3d691373fae4c0986bd.4e367b0d0155d801930846bb6ed98f8a7c23e0ded37888b29caa37009a40c7b9 to temp dir /tmp/tmpd6umgh5k\n",
      "INFO:pytorch_pretrained_bert.modeling:Model config {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"directionality\": \"bidi\",\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pooler_fc_size\": 768,\n",
      "  \"pooler_num_attention_heads\": 12,\n",
      "  \"pooler_num_fc_layers\": 3,\n",
      "  \"pooler_size_per_head\": 128,\n",
      "  \"pooler_type\": \"first_token_transform\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"vocab_size\": 119547\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model = BERTBiLSTMAttnClassifier.create(len(data.train_ds.cls2idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.train.train_clf import NerLearner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'1': 0, '0': 1}"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.train_ds.cls2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner = NerLearner(\n",
    "    model, data, \"/home/ubuntu/censor/cls_f.cpt\", t_total=num_epochs * len(data.train_dl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2889999"
      ]
     },
     "execution_count": 75,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.get_n_trainable_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "learner.fit(epochs=num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner.load_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.data.bert_data_clf import get_data_loader_for_predict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### On the validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "dl, ds = get_data_loader_for_predict(data, df_path=valid_df_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Predicting', max=210, style=ProgressStyle(description_width='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "preds = learner.predict(dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn_crfsuite.metrics import flat_classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "cls_report = flat_classification_report([f.cls for f in ds], preds, digits=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9952    0.9894    0.9923      5486\n",
      "           1     0.9927    0.9967    0.9947      7950\n",
      "\n",
      "   micro avg     0.9937    0.9937    0.9937     13436\n",
      "   macro avg     0.9940    0.9931    0.9935     13436\n",
      "weighted avg     0.9938    0.9937    0.9937     13436\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(cls_report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "errors = []\n",
    "for pred, feature in zip(preds, ds):\n",
    "    if pred != feature.cls:\n",
    "        errors.append({\n",
    "            \"text\": \" \".join(feature.tokens[1:-1]),\n",
    "            \"pred\": pred,\n",
    "            \"true\": feature.cls\n",
    "        })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Predicting', max=1, style=ProgressStyle(description_width='in…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "['1', '1', '1']"
      ]
     },
     "execution_count": 117,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "learner.predict(df=pd.DataFrame(\n",
    "    {\n",
    "        \"text\": ['дурацкий день', \"Путин гад\".lower(), \"что правда, гандон?\"], \"cls\": [\"0\", \"0\", \"0\"]\n",
    "}))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### On the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "dl, ds = get_data_loader_for_predict(data, df_path=test_df_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Predicting', max=17, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "preds = learner.predict(dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn_crfsuite.metrics import flat_classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "cls_report = flat_classification_report([f.cls for f in ds], preds, digits=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0     0.9757    0.9962    0.9858       523\n",
      "           1     0.9961    0.9751    0.9855       523\n",
      "\n",
      "   micro avg     0.9857    0.9857    0.9857      1046\n",
      "   macro avg     0.9859    0.9857    0.9857      1046\n",
      "weighted avg     0.9859    0.9857    0.9857      1046\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(cls_report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the misclassified test examples for manual inspection.\n",
    "errors = []\n",
    "for pred, feature in zip(preds, ds):\n",
    "    if pred != feature.cls:\n",
    "        errors.append({\n",
    "            # Skip the first and last tokens (presumably the special boundary markers).\n",
    "            \"text\": \" \".join(feature.tokens[1:-1]),\n",
    "            \"pred\": pred,\n",
    "            \"true\": feature.cls\n",
    "        })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.98565965583174"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Accuracy on the test set: the share of predictions that match the true label.\n",
    "1 - len(errors) / len(preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'text': 'есть чё?', 'pred': '0', 'true': '1'},\n",
       " {'text': 'я квир', 'pred': '0', 'true': '1'},\n",
       " {'text': 'тебя возбуждают другие коты?', 'pred': '0', 'true': '1'},\n",
       " {'text': 'что ты принимаешь?', 'pred': '0', 'true': '1'},\n",
       " {'text': 'ты что куришь?', 'pred': '0', 'true': '1'},\n",
       " {'text': 'посоветуй что-нибудь для расширения сознания',\n",
       "  'pred': '0',\n",
       "  'true': '1'},\n",
       " {'text': 'я колюсь', 'pred': '0', 'true': '1'},\n",
       " {'text': 'Чей Крым?', 'pred': '0', 'true': '1'},\n",
       " {'text': 'ты голосуешь?', 'pred': '0', 'true': '1'},\n",
       " {'text': 'Гнать дуру', 'pred': '0', 'true': '1'},\n",
       " {'text': 'АУЕ ворам свободу', 'pred': '0', 'true': '1'},\n",
       " {'text': 'Смолить чернушку', 'pred': '0', 'true': '1'},\n",
       " {'text': 'майн кампф', 'pred': '0', 'true': '1'},\n",
       " {'text': 'По одной из версий, он мог умереть от наркотического опьянения.',\n",
       "  'pred': '1',\n",
       "  'true': '0'},\n",
       " {'text': 'На одном из них было написано, что за убийство черного кота игрока ненавидит весь мир, а на другом поступок Луиша Фигу был назван \"позором \"Интера\".',\n",
       "  'pred': '1',\n",
       "  'true': '0'}]"
      ]
     },
     "execution_count": 112,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "errors"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
